{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# seq2seq模型测试\n",
    "---\n",
    "\n",
    "数据集构建方案不同，使用更复杂的模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "from tqdm import tqdm\n",
    "import sys\n",
    "import random\n",
    "import pprint\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "sys.path.insert(0, \"/home/team55/notespace/zengbin\")\n",
    "\n",
    "import jddc.utils as u\n",
    "import jddc.datasets as d\n",
    "from seq2seq.fields import *\n",
    "from seq2seq.optim import Optimizer\n",
    "from seq2seq.models import EncoderRNN, DecoderRNN, Seq2seq\n",
    "from seq2seq.loss import NLLLoss\n",
    "from seq2seq.supervised_trainer import SupervisedTrainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 参数配置\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2SeqConfig(object):\n",
    "    \"\"\"Seq2Seq模型参数配置\"\"\"\n",
    "    use_cuda = True\n",
    "    device = 1\n",
    "    teacher_forcing_ratio = 0.5\n",
    "\n",
    "    # encoder & decoder\n",
    "    hidden_size = 256\n",
    "    n_layers = 4\n",
    "    bidirectional = True\n",
    "    max_len = 300\n",
    "    rnn_cell = 'lstm'\n",
    "\n",
    "    encoder_params = u.AttrDict()\n",
    "    encoder_params['hidden_size'] = hidden_size\n",
    "    encoder_params['n_layers'] = n_layers\n",
    "    encoder_params['bidirectional'] = bidirectional\n",
    "    encoder_params['max_len'] = max_len\n",
    "    encoder_params['rnn_cell'] = rnn_cell\n",
    "    encoder_params['variable_lengths'] = True\n",
    "    encoder_params['input_dropout_p'] = 0\n",
    "    encoder_params['dropout_p'] = 0.5\n",
    "\n",
    "    decoder_params = u.AttrDict()\n",
    "    decoder_params['hidden_size'] = hidden_size*2 if bidirectional else hidden_size\n",
    "    decoder_params['n_layers'] = n_layers\n",
    "    decoder_params['bidirectional'] = bidirectional\n",
    "    decoder_params['max_len'] = max_len\n",
    "    decoder_params['rnn_cell'] = rnn_cell\n",
    "    decoder_params['use_attention'] = True\n",
    "    decoder_params['device'] = device\n",
    "    decoder_params['input_dropout_p'] = 0\n",
    "    decoder_params['dropout_p'] = 0.5\n",
    "\n",
    "    def __init__(self):\n",
    "        # 模型存储目录\n",
    "        self.s2s_path = os.path.join(\"/home/team55/notespace/data\", \"seq2seq02\")\n",
    "        u.insure_folder_exists(self.s2s_path)\n",
    "        self.file_train = os.path.join(self.s2s_path, \"train.tsv\")\n",
    "        # 翻转QQ分词结果\n",
    "        self.file_train_rq = os.path.join(self.s2s_path, \"train_reverse_q.tsv\")\n",
    "        self.log_file = os.path.join(self.s2s_path, \"seq2seq_02.log\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conf = Seq2SeqConfig()\n",
    "logger = u.create_logger(conf.log_file, name=\"s2s\", cmd=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载数据\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# solution for  _csv.Error: field larger than field limit (131072)\n",
    "import csv\n",
    "csv.field_size_limit(500 * 1024 * 1024)\n",
    "\n",
    "src = SourceField(batch_first=True)\n",
    "tgt = TargetField(batch_first=True)\n",
    "max_len = conf.max_len\n",
    "\n",
    "def len_filter(example):\n",
    "    return len(example.src) <= max_len and len(example.tgt) <= max_len\n",
    "\n",
    "train = torchtext.data.TabularDataset(\n",
    "    path=conf.file_train_rq, format='tsv',\n",
    "    fields=[('src', src), ('tgt', tgt)],\n",
    "    filter_pred=len_filter\n",
    ")\n",
    "\n",
    "src.build_vocab(train, max_size=100000)\n",
    "tgt.build_vocab(train, max_size=100000)\n",
    "input_vocab = src.vocab\n",
    "output_vocab = tgt.vocab"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 创建模型\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(conf.encoder_params)\n",
    "print(conf.decoder_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = NLLLoss()\n",
    "encoder = EncoderRNN(vocab_size=len(src.vocab), **conf.encoder_params)\n",
    "decoder = DecoderRNN(vocab_size=len(tgt.vocab), eos_id=tgt.eos_id, sos_id=tgt.sos_id, **conf.decoder_params)\n",
    "seq2seq = Seq2seq(encoder, decoder)\n",
    "\n",
    "device = conf.device\n",
    "if conf.use_cuda:\n",
    "    seq2seq.cuda(device)\n",
    "    loss.cuda(device)\n",
    "\n",
    "for param in seq2seq.parameters():\n",
    "    param.data.uniform_(-0.08, 0.08)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimizer and learning rate scheduler can be customized by\n",
    "# # explicitly constructing the objects and pass to the trainer.\n",
    "optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)\n",
    "# scheduler = StepLR(optimizer.optimizer, 1)\n",
    "# optimizer.set_scheduler(scheduler)\n",
    "# train\n",
    "trainer = SupervisedTrainer(loss=loss, batch_size=32, checkpoint_every=500, print_every=10,\n",
    "                            expt_dir=conf.s2s_path, random_seed=\"1234\", device=conf.device)\n",
    "trainer.logger = logger\n",
    "seq2seq = trainer.train(seq2seq, train, num_epochs=3, optimizer=optimizer, teacher_forcing_ratio=0.5, resume=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python3.6",
   "language": "python",
   "name": "python3.6"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
